import importlib.util
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS

if TYPE_CHECKING:
    from linearmodels.iv.results import IVResults


def load_shiftshare_module(root: Path):
    """Load ``analysis/19_iv_shiftshare_rollout_absorb.py`` under ``root`` as module ``shiftshare``.

    The script's file name starts with a digit, so it cannot be imported with a
    regular ``import`` statement; it is loaded directly from its path instead.

    Raises:
        FileNotFoundError: if the script does not exist under ``root``.
        ImportError: if no import spec/loader can be created for the path.
    """
    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    if not path.is_file():
        raise FileNotFoundError(f"Shift-share module not found: {path}")
    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def zscore(series: pd.Series) -> pd.Series:
    """Standardize ``series`` to mean 0 and population SD 1.

    Values are coerced to numeric first (unparseable entries become NaN and
    are ignored by the mean/SD). If the SD is NaN or not positive, the input
    times zero is returned, i.e. zeros with NaNs kept in place.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    degenerate = pd.isna(spread) or spread <= 0
    if degenerate:
        return values * 0
    center = values.mean()
    return (values - center) / spread


def _regional_exposure(csv_path: Path, baseline_year: int) -> pd.DataFrame:
    """Return z-scored, within-country-centred regional broadband access in ``baseline_year``.

    Reads a region-year CSV with ``region``, ``year`` and ``broadband_pc_hh``
    columns. Returns one row per region with a non-missing baseline value and
    columns ``region``, ``cntry`` (first two characters of the region code) and
    ``z_bb_access_base_<baseline_year>_ct_centered``.
    """
    raw_col = f"bb_access_base_{baseline_year}"
    centered_col = f"{raw_col}_ct_centered"
    z_col = f"z_{centered_col}"

    reg = pd.read_csv(csv_path)
    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    reg["broadband_pc_hh"] = pd.to_numeric(reg["broadband_pc_hh"], errors="coerce")

    base = reg[reg["year"] == baseline_year][["region", "broadband_pc_hh"]].rename(
        columns={"broadband_pc_hh": raw_col}
    )
    base["cntry"] = base["region"].str[:2]
    base = base.dropna(subset=[raw_col]).copy()
    # Subtract the country mean so the exposure only varies within country.
    base[centered_col] = base[raw_col] - base.groupby("cntry")[raw_col].transform("mean")
    base[z_col] = zscore(base[centered_col])
    return base[["region", "cntry", z_col]]


def _country_coverage_shocks(
    csv_path: Path,
    category_col: str,
    shock_names: dict[str, str],
    baseline_year: int,
) -> pd.DataFrame:
    """Return country-year changes in broadband coverage relative to ``baseline_year``.

    Reads a country-year CSV with ``cntry``, ``year``, ``coverage_pc_hh`` and
    ``category_col``. Coverage is averaged per (country, year, category). For
    each ``source -> shock`` pair in ``shock_names`` the shock is the country's
    ``source`` coverage minus its ``source`` coverage in ``baseline_year``
    (NaN when either side is missing; a category absent from the file yields an
    all-NaN shock).

    Returns columns ``cntry``, ``year`` and the shock names.
    """
    cov = pd.read_csv(csv_path)
    cov["year"] = pd.to_numeric(cov["year"], errors="coerce").astype("Int64")
    cov["coverage_pc_hh"] = pd.to_numeric(cov["coverage_pc_hh"], errors="coerce")
    cov = cov.dropna(subset=["cntry", "year", category_col]).copy()
    # Normalize like the region codes so stray case/whitespace cannot silently
    # break the merge with the region-derived ``cntry`` key.
    cov["cntry"] = cov["cntry"].astype(str).str.upper().str.strip()

    wide = (
        cov.pivot_table(index=["cntry", "year"], columns=category_col, values="coverage_pc_hh", aggfunc="mean")
        .reset_index()
        .rename_axis(None, axis=1)
    )
    sources = list(shock_names)
    for col in sources:
        if col not in wide.columns:
            wide[col] = np.nan

    base_cols = {col: f"{col}__base" for col in sources}
    base = wide[wide["year"] == baseline_year][["cntry", *sources]].rename(columns=base_cols)
    wide = wide.merge(base, on="cntry", how="left")
    for col, shock in shock_names.items():
        wide[shock] = wide[col] - wide[base_cols[col]]
    return wide[["cntry", "year", *shock_names.values()]]


def build_shiftshare_panel(root: Path, baseline_year: int) -> pd.DataFrame:
    """Build region x ESS-year shift-share instruments for a given baseline year.

    Exposure is regional broadband access (``broadband_pc_hh``) in
    ``baseline_year``, centred within country and z-scored. Shocks are
    country-level coverage changes since ``baseline_year`` for the
    ``MBPS_GT30`` / ``MBPS_GT100`` categories of ``inet_spd`` and the ``FTTP``
    category of ``inet_tec``. Each instrument is exposure x shock, z-scored over
    the whole panel.

    The panel covers every region with a baseline exposure and every ESS survey
    year strictly after ``baseline_year``.

    Returns:
        DataFrame with columns ``region``, ``survey_year``, ``z_ss_cov30``,
        ``z_ss_cov100``, ``z_ss_fttp`` and ``baseline_year``.
    """
    ext = root / "data_external"
    exposure_col = f"z_bb_access_base_{baseline_year}_ct_centered"
    exposure = _regional_exposure(ext / "broadband_region_year_eurostat_isoc_r_broad_h.csv", baseline_year)
    speed_shocks = _country_coverage_shocks(
        ext / "broadband_country_year_eurostat_isoc_cbs.csv",
        "inet_spd",
        {"MBPS_GT30": "shock_cov30", "MBPS_GT100": "shock_cov100"},
        baseline_year,
    )
    tech_shocks = _country_coverage_shocks(
        ext / "broadband_country_year_eurostat_isoc_cbt.csv",
        "inet_tec",
        {"FTTP": "shock_fttp"},
        baseline_year,
    )

    # Panel years: ESS survey years strictly after the baseline. Only
    # survey_year is used, so do not require any other column in the extract.
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet", columns=["survey_year"])
    years = sorted(int(y) for y in ess["survey_year"].dropna().unique().tolist())
    years = [y for y in years if y > baseline_year]
    regions = sorted(exposure["region"].unique().tolist())

    panel = pd.MultiIndex.from_product([regions, years], names=["region", "survey_year"]).to_frame(index=False)
    panel["cntry"] = panel["region"].str[:2]
    panel = panel.merge(exposure, on=["region", "cntry"], how="left")
    for shocks in (speed_shocks, tech_shocks):
        panel = panel.merge(
            shocks.rename(columns={"year": "survey_year"}),
            on=["cntry", "survey_year"],
            how="left",
        )

    # Shift-share instrument: regional exposure x country shock, standardized over the panel.
    for suffix, shock_col in (("cov30", "shock_cov30"), ("cov100", "shock_cov100"), ("fttp", "shock_fttp")):
        panel[f"z_ss_{suffix}"] = zscore(panel[exposure_col] * panel[shock_col])
    panel["baseline_year"] = baseline_year
    return panel[["region", "survey_year", "z_ss_cov30", "z_ss_cov100", "z_ss_fttp", "baseline_year"]]


def fit_iv_with_panel(mod, df: pd.DataFrame, y: str, controls: list[str], zcols: list[str]) -> "IVResults":
    """Fit a 2SLS of ``y`` on internet use and its interaction with ``edu_high``.

    Endogenous regressors are ``netusoft`` and ``netusoft * edu_high``. Each
    column in ``zcols`` and its product with ``edu_high`` serve as instruments.
    ``controls`` plus ``edu_high`` are exogenous. All model columns are first
    passed through ``mod.absorb_two_way`` with ``fe_a="region"`` and
    ``fe_b="ct"`` (presumably partialling out those two sets of fixed effects —
    confirm in the helper module); no intercept is added. Standard errors are
    clustered by ``region``.

    Args:
        mod: loaded shift-share helper module; must provide ``absorb_two_way``.
        df: data containing ``y``, ``netusoft``, ``edu_high``, ``ct``,
            ``region``, the controls and the ``zcols``.
        y: outcome column.
        controls: exogenous control columns.
        zcols: instrument columns (at least one).

    Returns:
        The fitted linearmodels IV results object.

    Raises:
        ValueError: if ``zcols`` is empty or no complete rows remain.
    """
    if not zcols:
        raise ValueError("zcols must contain at least one instrument column")
    cols = [y, "netusoft", "edu_high", *zcols, "ct", "region", *controls]
    d = df[cols].dropna().copy()
    if d.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")
    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    zcols_x_edu = [f"{z}_x_edu" for z in zcols]
    for z, z_x_edu in zip(zcols, zcols_x_edu):
        d[z_x_edu] = d[z] * d["edu_high"]

    model_cols = [y, "netusoft", "netusoft_x_edu", *zcols, *zcols_x_edu, *controls, "edu_high"]
    r = mod.absorb_two_way(d, model_cols, fe_a="region", fe_b="ct")

    y_r = r[y]
    exog = r[controls + ["edu_high"]]
    endog = r[["netusoft", "netusoft_x_edu"]]
    instr = r[zcols + zcols_x_edu]
    # NOTE(review): clusters are taken from `d`, not `r`; this assumes
    # absorb_two_way returns the same rows as `d` (same order/index). If it
    # ever drops rows (e.g. singletons), align the clusters first.
    return IV2SLS(y_r, exog=exog, endog=endog, instruments=instr).fit(cov_type="clustered", clusters=d["region"])


def to_latex_tabular(df: pd.DataFrame) -> str:
    """Render the sensitivity results as a booktabs LaTeX ``tabular``.

    Uses the columns listed in ``cols``, in that order (other columns of ``df``
    are ignored). Missing values become empty cells. ``nobs`` is printed as an
    integer, coefficients/SEs with 4 decimals, first-stage F statistics with 2
    decimals and ``AR_pvalue`` with 4 significant digits. The free-text
    ``label`` column is LaTeX-escaped so characters such as ``&``, ``%`` or
    ``_`` in a label cannot break the table.

    Args:
        df: one row per specification.

    Returns:
        The tabular environment as one string (no trailing newline).
    """
    cols = [
        "label",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_highEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
        "AR_pvalue",
    ]
    header = [
        "Spec",
        "N",
        "b(LowEdu)",
        "se(LowEdu)",
        "b(HighEdu)",
        "Delta(High-Low)",
        "se(Delta)",
        "F(net)",
        "F(net$\\times$edu)",
        "AR p",
    ]
    # Single-pass translation, so the replacement for "\" is not re-escaped.
    latex_escapes = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    d = df[cols].copy()
    d["label"] = d["label"].map(lambda x: str(x).translate(latex_escapes) if pd.notna(x) else "")
    d["nobs"] = d["nobs"].map(lambda x: f"{x:.0f}" if pd.notna(x) else "")
    for c in ["coef_lowEdu", "se_lowEdu", "coef_highEdu", "coef_deltaHighMinusLow", "se_deltaHighMinusLow"]:
        d[c] = d[c].map(lambda x: f"{x:.4f}" if pd.notna(x) else "")
    for c in ["fsF_netusoft", "fsF_netusoft_x_edu"]:
        d[c] = d[c].map(lambda x: f"{x:.2f}" if pd.notna(x) else "")
    d["AR_pvalue"] = d["AR_pvalue"].map(lambda x: f"{x:.4g}" if pd.notna(x) else "")

    # One left-aligned label column followed by right-aligned numeric columns.
    out = [
        "\\begin{tabular}{l" + "r" * (len(cols) - 1) + "}",
        "\\toprule",
        " & ".join(header) + " \\\\",
        "\\midrule",
    ]
    for row in d.itertuples(index=False):
        out.append(" & ".join(str(value) for value in row) + " \\\\")
    out.append("\\bottomrule")
    out.append("\\end{tabular}")
    return "\n".join(out)


def main() -> None:
    """Run the baseline-year sensitivity check for the shift-share IV and write the tables.

    For each outcome, fits the main specification via
    ``mod.fit_iv_interaction_absorb`` (labelled "Baseline 2016") and an
    alternative whose instruments come from ``build_shiftshare_panel`` with
    ``baseline_year=2019``. Both fits use only ESS survey years 2021 and 2024 so
    the samples are comparable. Results go to ``outputs/`` (CSV, Markdown) and
    ``paper_joc/tables/`` (LaTeX tabular).
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)
    paper_tables = root / "paper_joc" / "tables"
    paper_tables.mkdir(parents=True, exist_ok=True)

    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)
    controls = ["agea", "gndr", "hinctnta"]
    comparison_years = [2021, 2024]

    # Baseline 2019 sensitivity. Rename the 2019 instruments *before* merging so
    # their names never depend on whether `df` already has z_ss_* columns
    # (merge `suffixes` only apply on a name clash).
    rename_2019 = {
        "z_ss_cov30": "z_ss_cov30_2019",
        "z_ss_cov100": "z_ss_cov100_2019",
        "z_ss_fttp": "z_ss_fttp_2019",
    }
    panel_2019 = build_shiftshare_panel(root, baseline_year=2019)
    panel_2019 = panel_2019[["region", "survey_year", *rename_2019]].rename(columns=rename_2019)
    df2 = df.merge(panel_2019, on=["region", "survey_year"], how="left")
    df2 = df2[df2["survey_year"].isin(comparison_years)].copy()
    zcols_2019 = list(rename_2019.values())

    rows = []
    for y, ylab in [("y_part_index", "Participation index"), ("y_vote", "Vote (yes/no)")]:
        # Baseline 2016 (main spec), restricted to the same years for comparability.
        # A fresh copy per fit in case the helper modifies its input.
        d_main = df[df["survey_year"].isin(comparison_years)].copy()
        res_2016 = mod.fit_iv_interaction_absorb(d_main, y, controls=controls)
        rows.append(mod.summarize_result(res_2016, y=y, label=f"Baseline 2016 (years 2021/2024): {ylab}"))

        # Baseline 2019 alternative.
        res_2019 = fit_iv_with_panel(mod, df2, y=y, controls=controls, zcols=zcols_2019)
        rows.append(mod.summarize_result(res_2019, y=y, label=f"Baseline 2019 (years 2021/2024): {ylab}"))

    tab = pd.DataFrame(rows)
    (out_dir / "iv_shiftshare_baseline_sensitivity.csv").write_text(tab.to_csv(index=False), encoding="utf-8")
    (out_dir / "iv_shiftshare_baseline_sensitivity.md").write_text(mod.to_markdown_table(tab) + "\n", encoding="utf-8")
    (paper_tables / "iv_shiftshare_baseline_sensitivity.tex").write_text(to_latex_tabular(tab) + "\n", encoding="utf-8")
    print("Wrote baseline sensitivity table to paper_joc/tables/iv_shiftshare_baseline_sensitivity.tex")


if __name__ == "__main__":
    # Script entry point: run the sensitivity analysis only when executed directly, not on import.
    main()

